Team Information¶
This project is a collaborative effort by a group of passionate data science enthusiasts working together to solve the challenges of autonomous driving using state-of-the-art machine learning techniques.
Team Name: Red Coders¶
Team Members and Roles¶
| Name | Role | Contributions |
|---|---|---|
| Bhushan Asati | Data Scientist | Data preprocessing, model training, evaluation, and deployment. |
| Rujuta Dabke | Data Scientist | Model fine-tuning, hyperparameter optimization, and visualization. |
| Suyash Madhavi | Data Scientist | Dataset integration, system architecture, and backend scripting. |
| Anirudh Sharma | Data Scientist | Streamlit app development, UI/UX design, and deployment. |
Individual Contributions¶
1. Bhushan Asati¶
- Led the data preprocessing pipeline by mapping raw labels to simplified categories (`Human` and `Vehicle`).
- Trained multiple models, including DenseNet121 and ResNet50.
- Contributed to deployment by integrating TensorFlow Serving with Streamlit.
2. Rujuta Dabke¶
- Focused on hyperparameter tuning to achieve high accuracy for DenseNet121.
- Created interactive visualizations for results, including confusion matrices and performance metrics.
- Documented model evaluation results and insights.
3. Suyash Madhavi¶
- Designed and implemented the dataset loading mechanism from Kaggle.
- Scripted Python modules for preprocessing, training, and evaluation.
- Managed GitHub repository and ensured version control for the project.
4. Anirudh Sharma¶
- Developed the frontend interface using Streamlit.
- Designed user-friendly elements, including image upload, confidence visualizations, and slideshows.
- Optimized the app for scalability and deployed it on Docker.
Collaboration Tools¶
- Version Control: GitHub for repository management and collaboration.
- Task Management: Trello for tracking tasks and deadlines.
- Communication: Slack and Zoom for team meetings and discussions.
- Documentation: Jupyter Notebook for project documentation and experimentation.
Acknowledgment¶
We would like to thank our professor and peers for their support and guidance throughout this project. The teamwork and collaboration have made this a fulfilling and enriching learning experience.
Multiclass Object Classification in Autonomous Driving¶
This project is part of an initiative to classify objects in autonomous driving scenarios into two categories:
- Human: Including pedestrians, cyclists, etc.
- Vehicle: Cars, trucks, vans, etc.
Using the KITTI Dataset, we implemented a Deep Learning-based approach to perform real-time classification of objects to support autonomous vehicle navigation and decision-making.
Objectives¶
- Develop a robust object classification system tailored for autonomous driving.
- Train and fine-tune state-of-the-art deep learning models on the KITTI dataset.
- Optimize the system to classify objects into two major categories:
- Human: Pedestrians, cyclists, and other human-related objects.
- Vehicle: Cars, trucks, vans, and other automotive objects.
- Deploy the trained model using Streamlit to create an interactive and user-friendly application.
KITTI Dataset Overview¶
The KITTI 3D Object Detection Dataset is a widely used benchmark for autonomous driving research. It was created to facilitate the development and evaluation of perception systems for self-driving cars, and provides high-quality labeled data for 2D and 3D object detection tasks.
Key Details¶
- Source: KITTI Vision Benchmark Suite
- Categories: Pedestrian, Cyclist, Car, Truck, Van, Tram, Misc, etc.
- Primary Use Case: Object Detection, Classification, Tracking, and 3D Bounding Box Estimation
- Number of Images: ~7,481 labeled images in the training set
- Sensors: Captured using stereo cameras, Velodyne LiDAR, and GPS/IMU sensors
- Resolution: Images are generally 1242x375 pixels
- Location: Karlsruhe, Germany (urban and rural environments)
- File Formats:
  - Images: `.png` format
  - Annotations: `.txt` files containing bounding boxes
Dataset Components¶
The dataset consists of the following files:
RGB Images:
- Located in the `image_2/` directory.
- Front-view camera images captured during driving sequences.

Object Labels:
- Located in the `label_2/` directory.
- Text files containing detailed annotations for each image.

Calibration Data:
- Located in the `calib/` directory.
- Provides intrinsic and extrinsic calibration matrices for the camera and LiDAR.
Velodyne Data:
- Raw point clouds captured by the Velodyne LiDAR sensor.
- Can be used for 3D object detection but not used in this project.
Labels and Annotations¶
Each label file contains:
- Object type: `Car`, `Truck`, `Pedestrian`, `Cyclist`, etc.
- Bounding box coordinates (2D and optionally 3D).
- Object dimensions: Height, width, and length (for 3D bounding boxes).
- Truncation level: Degree to which an object is truncated by the image border.
- Occlusion level: Visibility of the object (0 = fully visible, 1 = partly occluded, 2 = largely occluded).
Column Details¶
The KITTI 3D Object Detection Dataset provides annotations in text files with detailed information for each object in an image. These annotations include 15 attributes (columns), which are explained below:
| Column Number | Attribute Name | Description | Example Value |
|---|---|---|---|
| 1 | Type | The object class: `Car`, `Van`, `Truck`, `Pedestrian`, `Person_sitting`, `Cyclist`, `Tram`, `Misc`, or `DontCare`. | `Car`, `Pedestrian` |
| 2 | Truncation | How much of the object is cut off at the image boundary, from 0 (not truncated) to 1 (completely truncated). | `0.00`, `0.50` |
| 3 | Occlusion | Occlusion level: 0 = fully visible, 1 = partly occluded, 2 = largely occluded, 3 = unknown. | `0`, `1` |
| 4 | Alpha | Observation angle of the object in the image plane, ranging from -π to π; represents the object’s orientation relative to the camera. | `-1.82`, `0.20` |
| 5-8 | 2D Bounding Box (xmin, ymin, xmax, ymax) | Pixel coordinates of the 2D bounding box that surrounds the object in the image. | `599.41 156.40 629.75 189.25` |
| 9 | Height | Height of the object in 3D space, in meters. | `1.89` |
| 10 | Width | Width of the object in 3D space, in meters. | `1.47` |
| 11 | Length | Length of the object in 3D space, in meters. | `4.02` |
| 12-14 | 3D Location (x, y, z) | 3D coordinates of the object’s center in the camera coordinate system, in meters (x: left/right, y: down/up, z: forward/backward). | `-1.47 1.90 46.74` |
| 15 | Rotation_y | Rotation of the object around the vertical (y) axis, in radians; indicates the object's orientation relative to the camera. | `-0.20` |
Column Explanations¶
Type:
- The main object classification label.
- Classes include `Car`, `Cyclist`, `Pedestrian`, etc.
- `DontCare` objects are ignored during evaluation and training.

Truncation:
- Indicates how much of the object is cut off at the image edges.
- Useful for determining whether an object is fully visible or partially outside the frame.

Occlusion:
- Captures the visibility of the object.
- A value of `0` means the object is fully visible, while `2` indicates heavy occlusion.
Alpha:
- The angle of the object’s orientation in the image plane.
- Helps in determining the direction the object is facing.
2D Bounding Box:
- Specifies the rectangular region in the image where the object is located.
- Useful for tasks like 2D object detection.
Height, Width, Length:
- The physical dimensions of the object in the real world (3D space).
3D Location:
- The object’s position relative to the camera.
- Measured in meters in the camera coordinate system.
Rotation_y:
- The object’s orientation in the 3D space, specifically its rotation around the vertical axis.
- Useful for tasks like 3D object detection and localization.
Example of an Annotation Row¶
Here’s an example row from a label file (`000001.txt`) in the dataset:

`Car 0.00 0 -1.82 599.41 156.40 629.75 189.25 1.89 1.47 4.02 -1.47 1.90 46.74 -0.20`
Column Breakdown:¶
- Type: `Car`
- Truncation: `0.00` (not truncated)
- Occlusion: `0` (fully visible)
- Alpha: `-1.82`
- 2D Bounding Box: `(xmin, ymin, xmax, ymax) = (599.41, 156.40, 629.75, 189.25)`
- Height: `1.89` m
- Width: `1.47` m
- Length: `4.02` m
- 3D Location: `(x, y, z) = (-1.47, 1.90, 46.74)`
- Rotation_y: `-0.20`
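The 15-column layout described above can be parsed with a few lines of Python. Below is a minimal sketch; the helper name `parse_kitti_line` and the dictionary keys are our own, chosen to mirror the column names.

```python
def parse_kitti_line(line):
    """Parse one 15-column KITTI label line into a dictionary."""
    f = line.strip().split()
    return {
        "type": f[0],
        "truncation": float(f[1]),
        "occlusion": int(f[2]),
        "alpha": float(f[3]),
        "bbox": tuple(float(v) for v in f[4:8]),         # xmin, ymin, xmax, ymax
        "dimensions": tuple(float(v) for v in f[8:11]),  # height, width, length
        "location": tuple(float(v) for v in f[11:14]),   # x, y, z
        "rotation_y": float(f[14]),
    }

row = "Car 0.00 0 -1.82 599.41 156.40 629.75 189.25 1.89 1.47 4.02 -1.47 1.90 46.74 -0.20"
obj = parse_kitti_line(row)
print(obj["type"], obj["bbox"])  # → Car (599.41, 156.4, 629.75, 189.25)
```

Mapping columns by slice rather than by index keeps the grouped attributes (bounding box, dimensions, location) together as tuples.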
Practical Use of Columns in the Project¶
2D Bounding Box (Columns 5-8):
- Used for localizing objects in the image for 2D object detection.
Type (Column 1):
- Simplified into two classes (`Human` and `Vehicle`) for this project.
Truncation and Occlusion (Columns 2-3):
- Not directly used in training but useful for data analysis and visualization.
3D Location (Columns 12-14):
- Helps in future extensions like 3D object detection or depth estimation.
Rotation_y (Column 15):
- Important for orientation-specific tasks in autonomous driving.
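To illustrate how the 2D bounding box columns are used for localization, the following sketch crops an object region out of an image array with NumPy slicing. The image here is a zero-filled placeholder with KITTI-like resolution, not a real frame, and `crop_bbox` is a hypothetical helper.

```python
import numpy as np

def crop_bbox(image, bbox):
    """Crop an object region given a KITTI-style (xmin, ymin, xmax, ymax) box in pixels."""
    xmin, ymin, xmax, ymax = (int(round(v)) for v in bbox)
    return image[ymin:ymax, xmin:xmax]

# Placeholder image with KITTI-like resolution (height 375, width 1242, RGB)
image = np.zeros((375, 1242, 3), dtype=np.uint8)
crop = crop_bbox(image, (599.41, 156.40, 629.75, 189.25))
print(crop.shape)  # (33, 31, 3)
```

Note that image arrays index rows (y) first, so the y range comes before the x range in the slice.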
Preprocessed Categories¶
For this project, we simplified the dataset into two primary categories:
- Human: includes `Pedestrian`, `Cyclist`, and other human-related labels.
- Vehicle: includes `Car`, `Truck`, `Van`, and `Tram`.

Objects labeled as `Misc` or `DontCare` were excluded to keep the focus on relevant classes.
Use Cases of the KITTI Dataset¶
Object Detection:
- Identify and classify objects in a scene.
3D Object Detection:
- Predict 3D bounding boxes for vehicles and pedestrians.
Object Tracking:
- Track object trajectories across multiple frames for video sequences.
Autonomous Driving Research:
- Train and evaluate perception systems for autonomous vehicles.
Advantages of KITTI Dataset¶
High Diversity:
- Contains images from urban, rural, and highway scenarios.
Rich Annotations:
- Provides detailed 2D and 3D annotations for comprehensive analysis.
Benchmark Dataset:
- Widely recognized and used in the autonomous driving research community.
Multi-Sensor Data:
- Combines camera, LiDAR, and GPS/IMU data for advanced applications.
Methodology¶
Data Preprocessing:
- Extract and preprocess images from the KITTI dataset.
- Normalize images to the range `[0, 1]`.
- Map object labels into the defined categories (`Human` and `Vehicle`).
Model Selection:
- Fine-tuned DenseNet121 for classification.
- Compared against models such as MobileNet, Inception-V3, and ResNet50.
Training:
- Used TensorFlow/Keras for model training.
- Augmented training data to mitigate overfitting.
Evaluation:
- Measured model performance using metrics like accuracy, precision, recall, and F1-score.
- Evaluated on a balanced test set to address class imbalance.
Deployment:
- Created an interactive Streamlit app for end users.
- Integrated the model with TensorFlow Serving for scalable deployment.
Results¶
Model Performance Metrics¶
Metrics Table¶
| Metric | Human | Vehicle | Overall |
|---|---|---|---|
| Precision | 77% | 100% | 96% (weighted avg) |
| Recall | 99% | 95% | 96% (weighted avg) |
| F1-Score | 87% | 97% | 96% (weighted avg) |
| Accuracy | — | — | 96% |
Classification Report¶
| Class | Precision | Recall | F1-Score | Support |
|---|---|---|---|---|
| Human | 77% | 99% | 87% | 163 |
| Vehicle | 100% | 95% | 97% | 960 |
Confusion Matrix¶
| Predicted Human | Predicted Vehicle | |
|---|---|---|
| Actual Human | 161 | 2 |
| Actual Vehicle | 48 | 912 |
Key Observations¶
- The model achieved an overall accuracy of 96%, demonstrating high reliability.
- Precision for the "Vehicle" class is near-perfect (100%), indicating almost no false positives.
- The "Human" class has lower precision (77%) because a number of vehicles are misclassified as "Human."
- Recall for the "Human" class is very high (99%), which means the model identifies most humans correctly.
- Weighted averages (precision, recall, and F1-score) confirm the balanced performance across classes.
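The per-class figures can be cross-checked directly from the confusion matrix above. This short sketch (variable names are ours) recomputes precision, recall, and overall accuracy from the four cell counts:

```python
# Confusion matrix from the table above: rows = actual, columns = predicted
tp_h, fn_h = 161, 2    # humans predicted as Human / as Vehicle
fp_h, tp_v = 48, 912   # vehicles predicted as Human / as Vehicle

precision_human = tp_h / (tp_h + fp_h)    # 161 / 209
recall_human = tp_h / (tp_h + fn_h)       # 161 / 163
precision_vehicle = tp_v / (tp_v + fn_h)  # 912 / 914
recall_vehicle = tp_v / (tp_v + fp_h)     # 912 / 960
accuracy = (tp_h + tp_v) / (tp_h + fn_h + fp_h + tp_v)

print(f"Human precision {precision_human:.0%}, recall {recall_human:.0%}")
print(f"Vehicle precision {precision_vehicle:.0%}, recall {recall_vehicle:.0%}")
print(f"Accuracy {accuracy:.0%}")
```

Rounded to whole percentages, these reproduce the metrics table: Human 77%/99%, Vehicle 100%/95%, accuracy 96%.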
Deployment¶
The trained model was deployed using Streamlit, creating an intuitive web-based application:
Features:
- Upload an image to classify it as either `Human` or `Vehicle`.
- View predictions with confidence scores.
- Visualize predictions with interactive charts.
Tools:
- Streamlit: For frontend development.
- TensorFlow Serving: For serving the trained model.
- Docker: For containerizing the application for scalability.
Streamlit App Workflow¶
- User uploads an image.
- The image is preprocessed and passed to the DenseNet121 model.
- The model predicts the class (`Human` or `Vehicle`) and returns confidence scores.
- Results are displayed with visualizations (e.g., bar chart, pie chart).
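The prediction step of this workflow, turning the model's softmax output into a label and a confidence score, can be sketched as follows. The probability values and the helper name `postprocess` are illustrative, not taken from the app's source.

```python
import numpy as np

CLASS_NAMES = ["Human", "Vehicle"]

def postprocess(probs):
    """Map a softmax probability vector to a (label, confidence) pair."""
    idx = int(np.argmax(probs))  # index of the most likely class
    return CLASS_NAMES[idx], float(probs[idx])

# Example: a softmax output leaning heavily toward "Vehicle"
label, confidence = postprocess(np.array([0.12, 0.88]))
print(label, confidence)  # Vehicle 0.88
```

The same pair feeds the app's confidence visualizations (bar and pie charts).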
Challenges¶
Class Imbalance:
- Categories like cyclists were underrepresented in the dataset.
- Solution: Used weighted loss functions and oversampling for minority classes.
Model Overfitting:
- Deep learning models tended to overfit on the training set.
- Solution: Applied data augmentation and dropout layers during training.
Deployment Complexity:
- Integrating the trained model with a user-friendly interface.
- Solution: Used TensorFlow Serving and Streamlit for seamless deployment.
Real-Time Performance:
- Balancing accuracy and inference time for real-time predictions.
- Solution: Selected a lightweight model (DenseNet121) for faster inference.
Conclusion¶
This project successfully demonstrated the use of deep learning techniques for real-time object classification in autonomous driving systems:
- DenseNet121 achieved the highest accuracy among the models compared, reaching 97.24% on the held-out test set.
- The deployed Streamlit app provides an interactive interface for end users.
- The system is scalable and can be extended to include additional classes or real-time video streams.
Future Work¶
- Extend classification to include additional object categories (e.g., traffic signs, animals).
- Integrate real-time video processing for continuous object detection.
- Optimize model for edge devices using TensorFlow Lite.
- Address edge cases by collecting and training on additional datasets.
!pip install tensorflow_addons
import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
import pickle
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import DenseNet121
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from collections import Counter
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
from google.colab import files
files.upload()
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
Download the KITTI dataset using Kaggle API
from kaggle.api.kaggle_api_extended import KaggleApi
api = KaggleApi()
api.authenticate()
dataset_name = 'garymk/kitti-3d-object-detection-dataset'
dataset_path = '/content/kitti_dataset/'
if not os.path.exists(dataset_path):
os.makedirs(dataset_path)
!kaggle datasets download -d {dataset_name} -p {dataset_path}
!unzip -q {dataset_path}kitti-3d-object-detection-dataset.zip -d {dataset_path}
print("Dataset downloaded and extracted successfully.")
else:
print("Dataset already exists. Skipping download.")
Dataset URL: https://www.kaggle.com/datasets/garymk/kitti-3d-object-detection-dataset License(s): unknown Downloading kitti-3d-object-detection-dataset.zip to /content/kitti_dataset 100% 30.0G/30.0G [04:09<00:00, 99.0MB/s] 100% 30.0G/30.0G [04:10<00:00, 129MB/s] Dataset downloaded and extracted successfully.
# Data Acquisition and Loading
# Paths and hyperparameters
dataset_path = '/content/kitti_dataset/'
image_dir = os.path.join(dataset_path, 'training/image_2/')
label_dir = os.path.join(dataset_path, 'training/label_2/')
MODEL_SAVE_PATH = 'models/fine_tuned_densenet121_reduced_categories.h5'
LABEL_ENCODER_PATH = 'label_encoder.pkl'
TARGET_SIZE = (224, 224)
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 1e-4
# Ensure model directory exists
os.makedirs('models', exist_ok=True)
# Get a list of all image files
image_files = [f for f in os.listdir(image_dir) if f.endswith('.png')]
class_mapping = {
'Pedestrian': 'Human',
'Person_sitting': 'Human',
'Cyclist': 'Human',
'Car': 'Vehicle',
'Truck': 'Vehicle',
'Van': 'Vehicle',
'Tram': 'Vehicle',
'Misc': None, # Exclude Miscellaneous
'DontCare': None # Exclude DontCare
}
def map_label(label_path):
with open(label_path, 'r') as f:
lines = f.readlines()
mapped_classes = []
for line in lines:
data = line.strip().split()
obj_class = data[0]
mapped_class = class_mapping.get(obj_class)
if mapped_class is not None:
mapped_classes.append(mapped_class)
return mapped_classes
images = []
labels = []
for image_file in image_files:
    image_path = os.path.join(image_dir, image_file)
    label_path = os.path.join(label_dir, image_file.replace('.png', '.txt'))
    mapped_classes = map_label(label_path)
    if not mapped_classes:
        # Skip images with no valid (Human/Vehicle) objects,
        # before appending the image, so X and y stay aligned
        continue
    image = cv2.imread(image_path)
    image_resized = cv2.resize(image, TARGET_SIZE)
    image_rgb = cv2.cvtColor(image_resized, cv2.COLOR_BGR2RGB)
    images.append(image_rgb)
    # Use the first valid object as the image-level label
    labels.append(mapped_classes[0])
X = np.array(images)
y = np.array(labels)
X = X / 255.0
# Encode labels
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)
class_names = list(label_encoder.classes_)
print("Reduced Categories:", class_names) # Should show ["Human", "Vehicle"]
Reduced Categories: ['Human', 'Vehicle']
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y_encoded, test_size=0.3, random_state=42, stratify=y_encoded
)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=42, stratify=y_temp
)
# Plot the class distribution in the training data
# (the split must run first so that y_train exists)
label_counts = Counter(y_train)
plt.figure(figsize=(8, 5))
sns.barplot(x=list(label_counts.keys()), y=list(label_counts.values()), palette='viridis')
plt.title("Class Distribution in Training Data")
plt.xlabel("Class")
plt.ylabel("Count")
plt.xticks(ticks=range(len(class_names)), labels=class_names)
plt.show()
class_counts = Counter(y_train)
total_samples = len(y_train)
class_weights = {}
for class_idx, count in class_counts.items():
class_weights[class_idx] = total_samples / (len(class_names) * count)
print("Class Weights:", class_weights)
Class Weights: {1: 0.5847665847665847, 0: 3.449275362318841}
!pip install shap
train_datagen = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    zoom_range=0.1
)
test_datagen = ImageDataGenerator()
train_gen = train_datagen.flow(X_train, y_train, batch_size=BATCH_SIZE)
val_gen = test_datagen.flow(X_val, y_val, batch_size=BATCH_SIZE)
test_gen = test_datagen.flow(X_test, y_test, batch_size=BATCH_SIZE, shuffle=False)
# Preview a batch of augmented images
# (the generators must be defined before sampling from train_gen)
augmented_images = next(train_gen)[0]
plt.figure(figsize=(10, 10))
for i in range(9):
    plt.subplot(3, 3, i + 1)
    plt.imshow(augmented_images[i])
    plt.axis('off')
plt.suptitle("Sample Augmented Images")
plt.show()
densenet_base = DenseNet121(weights='imagenet', include_top=False, input_shape=(224,224,3))
densenet_base.trainable = True
for layer in densenet_base.layers[:-10]:
layer.trainable = False
model = models.Sequential([
densenet_base,
layers.GlobalAveragePooling2D(),
layers.Dense(256, activation='relu'),
layers.Dense(len(class_names), activation='softmax')
])
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/densenet/densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5
29084464/29084464 [==============================] - 0s 0us/step
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
densenet121 (Functional) (None, 7, 7, 1024) 7037504
global_average_pooling2d ( (None, 1024) 0
GlobalAveragePooling2D)
dense (Dense) (None, 256) 262400
dense_1 (Dense) (None, 2) 514
=================================================================
Total params: 7300418 (27.85 MB)
Trainable params: 431042 (1.64 MB)
Non-trainable params: 6869376 (26.20 MB)
_________________________________________________________________
history = model.fit(
train_gen,
validation_data=val_gen,
epochs=EPOCHS,
class_weight=class_weights
)
Epoch 1/10 164/164 [==============================] - 130s 759ms/step - loss: 0.3987 - accuracy: 0.8520 - val_loss: 0.1743 - val_accuracy: 0.9563 Epoch 2/10 164/164 [==============================] - 119s 726ms/step - loss: 0.2696 - accuracy: 0.9064 - val_loss: 0.1469 - val_accuracy: 0.9537 Epoch 3/10 164/164 [==============================] - 119s 727ms/step - loss: 0.2106 - accuracy: 0.9274 - val_loss: 0.1442 - val_accuracy: 0.9608 Epoch 4/10 164/164 [==============================] - 119s 724ms/step - loss: 0.1675 - accuracy: 0.9473 - val_loss: 0.1334 - val_accuracy: 0.9572 Epoch 5/10 164/164 [==============================] - 122s 742ms/step - loss: 0.1273 - accuracy: 0.9565 - val_loss: 0.1135 - val_accuracy: 0.9599 Epoch 6/10 164/164 [==============================] - 118s 716ms/step - loss: 0.1021 - accuracy: 0.9656 - val_loss: 0.0988 - val_accuracy: 0.9670 Epoch 7/10 164/164 [==============================] - 116s 707ms/step - loss: 0.0847 - accuracy: 0.9702 - val_loss: 0.1021 - val_accuracy: 0.9697 Epoch 8/10 164/164 [==============================] - 110s 671ms/step - loss: 0.0661 - accuracy: 0.9744 - val_loss: 0.1047 - val_accuracy: 0.9688 Epoch 9/10 164/164 [==============================] - 112s 679ms/step - loss: 0.0839 - accuracy: 0.9681 - val_loss: 0.0742 - val_accuracy: 0.9768 Epoch 10/10 164/164 [==============================] - 111s 677ms/step - loss: 0.0651 - accuracy: 0.9767 - val_loss: 0.1044 - val_accuracy: 0.9661
test_loss, test_acc = model.evaluate(test_gen)
print(f"Test Accuracy after fine-tuning with reduced categories: {test_acc*100:.2f}%")
36/36 [==============================] - 19s 529ms/step - loss: 0.0821 - accuracy: 0.9724 Test Accuracy after fine-tuning with reduced categories: 97.24%
y_pred_probs = model.predict(test_gen)
y_pred = np.argmax(y_pred_probs, axis=1)
36/36 [==============================] - 22s 541ms/step
# Loss and accuracy curves side by side
# (a single figure; showing the loss subplot early would detach the accuracy subplot)
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title("Model Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title("Model Accuracy")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.legend()
plt.tight_layout()
plt.show()
print("Classification Report:")
print(classification_report(y_test, y_pred, target_names=class_names))
Classification Report:
precision recall f1-score support
Human 0.85 0.98 0.91 163
Vehicle 1.00 0.97 0.98 960
accuracy 0.97 1123
macro avg 0.93 0.97 0.95 1123
weighted avg 0.98 0.97 0.97 1123
from sklearn.metrics import precision_score, recall_score, f1_score
import pandas as pd
precision = precision_score(y_test, y_pred, average=None)
recall = recall_score(y_test, y_pred, average=None)
f1 = f1_score(y_test, y_pred, average=None)
metrics_df = pd.DataFrame({
"Class": class_names,
"Precision": precision,
"Recall": recall,
"F1-Score": f1
})
metrics_df.plot(x="Class", kind="bar", figsize=(10, 6))
plt.title("Per-Class Performance Metrics")
plt.xlabel("Class")
plt.ylabel("Score")
plt.legend(loc="lower right")
plt.show()
import random
num_samples = 10
indices = random.sample(range(len(X_test)), num_samples)
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for i, ax in enumerate(axes.flatten()):
idx = indices[i]
ax.imshow(X_test[idx])
true_label = class_names[y_test[idx]]
pred_label = class_names[y_pred[idx]]
ax.set_title(f"True: {true_label}\nPred: {pred_label}")
ax.axis('off')
plt.tight_layout()
plt.show()
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8,6))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names, cmap='Blues')
plt.title('Confusion Matrix after Reducing Categories')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.show()
confidence_scores = np.max(y_pred_probs, axis=1)
plt.figure(figsize=(8, 5))
sns.histplot(confidence_scores, bins=10, kde=True)
plt.title("Distribution of Confidence Scores")
plt.xlabel("Confidence Score")
plt.ylabel("Frequency")
plt.show()
misclassified_indices = np.where(y_test != y_pred)[0]
plt.figure(figsize=(10, 10))
for i, idx in enumerate(misclassified_indices[:9]):
plt.subplot(3, 3, i+1)
plt.imshow(X_test[idx])
true_label = class_names[y_test[idx]]
pred_label = class_names[y_pred[idx]]
plt.title(f"True: {true_label}\nPred: {pred_label}")
plt.axis('off')
plt.tight_layout()
plt.show()
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize
# y_test holds integer class indices; binarize it for per-class ROC curves
n_classes = len(class_names)
y_test_binarized = label_binarize(y_test, classes=range(n_classes))
# With two classes, label_binarize returns a single column; expand to one column per class
if y_test_binarized.shape[1] == 1 and n_classes > 1:
    y_test_binarized = np.hstack((1 - y_test_binarized, y_test_binarized))
fpr = {}
tpr = {}
roc_auc = {}
plt.figure(figsize=(10, 8))
for i, class_name in enumerate(class_names):
fpr[i], tpr[i], _ = roc_curve(y_test_binarized[:, i], y_pred_probs[:, i])
roc_auc[i] = auc(fpr[i], tpr[i])
plt.plot(fpr[i], tpr[i], label=f"{class_name} (AUC = {roc_auc[i]:.2f})")
plt.plot([0, 1], [0, 1], "k--", label="Random")
plt.title("ROC Curves by Class")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.legend(loc="lower right")
plt.show()
import time
inference_times = []
for i in range(50):  # Time 50 single-image predictions
    start = time.time()
    model.predict(np.expand_dims(X_test[i], axis=0), verbose=0)  # silence per-call logs
    inference_times.append(time.time() - start)
plt.figure(figsize=(8, 5))
sns.histplot(inference_times, bins=10, kde=True)
plt.title("Distribution of Inference Times")
plt.xlabel("Time (seconds)")
plt.ylabel("Frequency")
plt.show()